#include <mpi.h>
#include "buffer.hpp"
#include "esmd_types.h"
#include "cell.h"
#include "comm.h"
// Recursion terminator for the variadic wait_all below: nothing left to wait on.
__always_inline static void wait_all(){
}
// Wait, in order, for each of the given MPI requests to complete.
// @param req   first pending request; completed (and freed) by MPI_Wait
// @param rest  remaining requests, handled by the recursive call
template <typename... Ts>
__always_inline void wait_all(MPI_Request &req, Ts... rest){
  // MPI_STATUS_IGNORE tells the library not to fill in a status we never
  // read, instead of passing a dummy MPI_Status object.
  MPI_Wait(&req, MPI_STATUS_IGNORE);
  wait_all(rest...);
}
// Run `pack` over every cell of `brick`, appending that cell's data to
// `buffer`. Cells are visited x-outermost, z-innermost; the packer is
// handed the cell's linear offset within `grid`.
template<typename PackerType>
void pack_brick(pack_buffer &buffer, cellgrid_t *grid, box<int> &brick, PackerType &pack){
  vec<int> &lo = brick.lo;
  vec<int> &hi = brick.hi;
  for (int x = lo.x; x < hi.x; ++x)
    for (int y = lo.y; y < hi.y; ++y)
      for (int z = lo.z; z < hi.z; ++z)
        pack(buffer, get_offset_xyz<true>(grid, x, y, z));
}

// Run `unpack` over every cell of `brick`, consuming that cell's data from
// `buffer`. Traversal order (x-outermost, z-innermost) mirrors pack_brick,
// so a buffer packed by pack_brick over the matching brick unpacks in step.
template<typename UnpackerType>
void unpack_brick(unpack_buffer &buffer, cellgrid_t *grid, box<int> &brick, UnpackerType &unpack){
  vec<int> &lo = brick.lo;
  vec<int> &hi = brick.hi;
  for (int x = lo.x; x < hi.x; ++x)
    for (int y = lo.y; y < hi.y; ++y)
      for (int z = lo.z; z < hi.z; ++z)
        unpack(buffer, get_offset_xyz<true>(grid, x, y, z));
}

template<typename PackerType, typename UnpackerType>
void forward_comm_template(mpp_t *mpp, cellgrid_t *grid, PackerType &pack, UnpackerType unpack){
  pack_buffer pack_prev(mpp->send_prev);
  pack_buffer pack_next(mpp->send_next);
  unpack_buffer unpack_prev(mpp->recv_prev);
  unpack_buffer unpack_next(mpp->recv_next);
  box<int> local_box(vec<int>{0, 0, 0}, grid->nlocal);

  box<int> z_prev_send = box<int>(vec<int>(0, 0, 0), vec<int>(grid->nlocal.x, grid->nlocal.y, grid->nn));
  box<int> z_next_send = z_prev_send + vec<int>(0, 0, grid->nlocal.z - grid->nn);
  box<int> z_prev_recv = z_prev_send - vec<int>(0, 0, grid->nn);
  box<int> z_next_recv = z_next_send + vec<int>(0, 0, grid->nn);

  MPI_Request recv_req_prev, recv_req_next, send_req_prev, send_req_next;

  MPI_Irecv(mpp->recv_prev, mpp->max_comm_size, MPI_BYTE, mpp->prev.z, stag_next, mpp->comm, &recv_req_prev);
  MPI_Irecv(mpp->recv_next, mpp->max_comm_size, MPI_BYTE, mpp->next.z, stag_prev, mpp->comm, &recv_req_next);

  pack_brick(pack_prev, grid, z_prev_send, pack);
  pack_brick(pack_next, grid, z_next_send, pack);


  MPI_Isend(mpp->send_prev, pack_prev.offset, MPI_BYTE, mpp->prev.z, stag_prev, mpp->comm, &send_req_prev);
  MPI_Isend(mpp->send_next, pack_next.offset, MPI_BYTE, mpp->next.z, stag_next, mpp->comm, &send_req_next);

  wait_all(recv_req_next, recv_req_prev);

  unpack_brick(unpack_next, grid, z_next_recv, unpack);
  unpack_brick(unpack_prev, grid, z_prev_recv, unpack);

  wait_all(send_req_prev, send_req_next);
  //printf("%d %d %d %d\n", pack_prev.offset, pack_next.offset, unpack_prev.offset, unpack_next.offset);  
  pack_prev.offset = 0;
  pack_next.offset = 0;
  unpack_prev.offset = 0;
  unpack_next.offset = 0;
  box<int> y_prev_send = box<int>(vec<int>(0, 0, grid->dim.lo.z), vec<int>(grid->nlocal.x, grid->nn, grid->dim.hi.z));
  box<int> y_next_send = y_prev_send + vec<int>(0, grid->nlocal.y - grid->nn, 0);
  box<int> y_prev_recv = y_prev_send - vec<int>(0, grid->nn, 0);
  box<int> y_next_recv = y_next_send + vec<int>(0, grid->nn, 0);

  MPI_Irecv(mpp->recv_prev, mpp->max_comm_size, MPI_BYTE, mpp->prev.y, stag_next, mpp->comm, &recv_req_prev);
  MPI_Irecv(mpp->recv_next, mpp->max_comm_size, MPI_BYTE, mpp->next.y, stag_prev, mpp->comm, &recv_req_next);

  pack_brick(pack_prev, grid, y_prev_send, pack);
  pack_brick(pack_next, grid, y_next_send, pack);

  MPI_Isend(mpp->send_prev, pack_prev.offset, MPI_BYTE, mpp->prev.y, stag_prev, mpp->comm, &send_req_prev);
  MPI_Isend(mpp->send_next, pack_next.offset, MPI_BYTE, mpp->next.y, stag_next, mpp->comm, &send_req_next);

  wait_all(recv_req_next, recv_req_prev);
  
  unpack_brick(unpack_next, grid, y_next_recv, unpack);
  unpack_brick(unpack_prev, grid, y_prev_recv, unpack);

  wait_all(send_req_prev, send_req_next);

  pack_prev.offset = 0;
  pack_next.offset = 0;
  unpack_prev.offset = 0;
  unpack_next.offset = 0;
  box<int> x_prev_send = box<int>(vec<int>(0, grid->dim.lo.y, grid->dim.lo.z), vec<int>(grid->nn, grid->dim.hi.y, grid->dim.hi.z));
  box<int> x_next_send = x_prev_send + vec<int>(grid->nlocal.x - grid->nn, 0, 0);
  box<int> x_prev_recv = x_prev_send - vec<int>(grid->nn, 0, 0);
  box<int> x_next_recv = x_next_send + vec<int>(grid->nn, 0, 0);
  
  MPI_Irecv(mpp->recv_prev, mpp->max_comm_size, MPI_BYTE, mpp->prev.x, stag_next, mpp->comm, &recv_req_prev);
  MPI_Irecv(mpp->recv_next, mpp->max_comm_size, MPI_BYTE, mpp->next.x, stag_prev, mpp->comm, &recv_req_next);

  pack_brick(pack_prev, grid, x_prev_send, pack);
  pack_brick(pack_next, grid, x_next_send, pack);

  MPI_Isend(mpp->send_prev, pack_prev.offset, MPI_BYTE, mpp->prev.x, stag_prev, mpp->comm, &send_req_prev);
  MPI_Isend(mpp->send_next, pack_next.offset, MPI_BYTE, mpp->next.x, stag_next, mpp->comm, &send_req_next);

  wait_all(recv_req_next, recv_req_prev);
  
  unpack_brick(unpack_next, grid, x_next_recv, unpack);
  unpack_brick(unpack_prev, grid, x_prev_recv, unpack);

  wait_all(send_req_prev, send_req_next);
}

